import pandas as pd
import torch
from transformers import pipeline
from tqdm import tqdm
import json
import csv
import os

# --- Configuration ---
# Token-level input CSV: one token per row, read via its `text` and `ner`
# columns by preprocess_data().
# NOTE(review): the path repeats `raw_ner_conll_data.csv` as both a directory
# and a file name — presumably an archive extracted into a folder of that
# name; confirm this is intentional.
RAW_INPUT_CSV = 'data/raw/raw_ner_conll_data.csv/raw_ner_conll_data.csv'
# Document-level intermediate: written by preprocess_data(), read by
# run_conflibert_inference().
PREPROCESSED_CSV = 'data/model_outputs/ner_document_data.csv'
# Final output: the document-level data plus a `conflibert_entities` column.
OUTPUT_CSV = 'data/model_outputs/results_with_conflibert.csv'
# Model identifier passed to transformers.pipeline("ner", model=...).
CONFLIBERT_MODEL = "eventdata-utd/conflibert-named-entity-recognition"

def preprocess_data(input_csv=None, output_csv=None):
    """Convert token-level NER rows into document-level records and save them.

    Each row of the input CSV holds one token (``text``) and its tag (``ner``).
    Tokens are accumulated in order and joined with single spaces. A row whose
    ``text`` is empty, missing, or exactly ``','`` closes the current document.
    The per-token tags of each document are kept alongside it.

    Args:
        input_csv: Path of the token-level CSV. ``None`` (default) uses
            ``RAW_INPUT_CSV``.
        output_csv: Path of the document-level CSV to write. ``None`` (default)
            uses ``PREPROCESSED_CSV``. Missing parent directories are created.

    Returns:
        pandas.DataFrame with columns ``text`` (space-joined tokens) and
        ``gold_labels_bio`` (list of tag strings). In the written CSV the list
        column is stored in its Python ``str()`` form, e.g. ``['O', 'B-PER']``.
    """
    input_csv = RAW_INPUT_CSV if input_csv is None else input_csv
    output_csv = PREPROCESSED_CSV if output_csv is None else output_csv
    print(f"Reading {input_csv} and converting to documents...")

    # Treat only truly empty fields as missing. With pandas' default NA list,
    # genuine tokens such as "NA", "null" or "None" would be parsed as NaN and
    # then mistaken for document boundaries below.
    df_raw = pd.read_csv(input_csv, keep_default_na=False, na_values=[''])

    documents = []
    current_tokens = []
    current_labels = []

    # Iterate the two columns directly instead of building a Series per row.
    for raw_text, raw_label in zip(df_raw['text'], df_raw['ner']):
        text_token = str(raw_text).strip()

        # NOTE(review): a bare ',' token is treated as a boundary, as before.
        # In CoNLL-style data commas are usually real tokens, so this may split
        # sentences at every comma — confirm it matches the raw file's
        # separator convention.
        if pd.isna(raw_text) or text_token in ('', ','):
            if current_tokens:
                documents.append({
                    'text': ' '.join(current_tokens),
                    'gold_labels_bio': current_labels,
                })
                # Rebind (not clear) so the stored list is left untouched.
                current_tokens = []
                current_labels = []
        else:
            current_tokens.append(text_token)
            current_labels.append(str(raw_label).strip())

    # Flush the trailing document when the file does not end with a boundary.
    if current_tokens:
        documents.append({
            'text': ' '.join(current_tokens),
            'gold_labels_bio': current_labels,
        })

    # Fixed columns keep the header even when no documents were found, so the
    # inference step can still read a `text` column.
    df_documents = pd.DataFrame(documents, columns=['text', 'gold_labels_bio'])
    out_dir = os.path.dirname(output_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df_documents.to_csv(output_csv, index=False)
    print(f"Created {len(documents)} documents. Saved to {output_csv}")
    return df_documents
    
def run_conflibert_inference(input_csv=None, output_csv=None):
    """Tag each pre-processed document with ConfliBERT and save the results.

    Reads the document-level CSV produced by ``preprocess_data``. Runs the
    ``CONFLIBERT_MODEL`` "ner" pipeline (``aggregation_strategy="simple"``) on
    every ``text`` value. Stores the predicted entities as a JSON list of
    ``{"entity_text": ..., "entity_label": ...}`` objects in a new
    ``conflibert_entities`` column.

    Args:
        input_csv: Document-level CSV to read. ``None`` (default) uses
            ``PREPROCESSED_CSV``.
        output_csv: Where to write the augmented CSV. ``None`` (default) uses
            ``OUTPUT_CSV``. Missing parent directories are created.

    Returns:
        The DataFrame that was written, including ``conflibert_entities``.
    """
    input_csv = PREPROCESSED_CSV if input_csv is None else input_csv
    output_csv = OUTPUT_CSV if output_csv is None else output_csv

    # Prefer a CUDA GPU, then Apple MPS, then fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")
    ner_pipeline = pipeline("ner", model=CONFLIBERT_MODEL, aggregation_strategy="simple", device=device)

    # keep_default_na=False: a document whose text is e.g. "NA" or "None" must
    # reach the model verbatim, not as NaN (which str() would turn into "nan").
    df = pd.read_csv(input_csv, keep_default_na=False)

    # NOTE(review): no truncation/stride is configured, so documents longer
    # than the model's maximum sequence length may fail or be mishandled
    # depending on the installed transformers version — confirm.
    predictions = []
    for text in tqdm(df['text'], desc="Running ConfliBERT Inference"):
        results = ner_pipeline(str(text))
        # Keep only the surface text and the aggregated entity type.
        entities = [{"entity_text": r['word'], "entity_label": r['entity_group']} for r in results]
        predictions.append(json.dumps(entities))

    df['conflibert_entities'] = predictions
    out_dir = os.path.dirname(output_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(output_csv, index=False)
    print(f"ConfliBERT results saved to {output_csv}")
    return df

if __name__ == "__main__":
    # Run the pipeline stages in order: build documents, then tag them.
    for stage in (preprocess_data, run_conflibert_inference):
        stage()